from datetime import datetime
import time
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier, HistGradientBoostingClassifier
from sklearn import metrics
from evidently.report import Report
from evidently.metrics import DataDriftTable, DatasetDriftMetric
import shap
import xgboost as xgb
import pickle
import graphviz
from tableone import TableOne, load_dataset
from IPython.display import Latex
from counterfactuals import *
C:\Users\boris\anaconda3\envs\INNO\lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
shap.initjs()
pd.set_option('display.max_columns', None)
warnings.filterwarnings('ignore')
XAI-trust overzicht explanations¶
Dit notebook dient als toelichting bij alle explanations die het XAI-trust project gemaakt heeft aan de hand van een model dat sepsis-geassocieerd delirium voorspelt. Het is niet nodig om de paper over het model te lezen, maar kan wellicht wel context bieden: https://doi.org/10.1038/s41598-023-38650-4
In de markdown cellen staat informatie over hoe de explanation afgelezen dient te worden. Sommige stukjes code zijn voorzien van comments die hulp kunnen bieden bij de implementatie op de dashboard.
Glossary¶
Categoriaal: Wanneer een feature categoriaal is kunnen we de waarden van die feature niet sorteren op grootte. Een feature 'haarkleur' zou bijvoorbeeld categoriaal zijn.
Correlatie: Ook wel 'lineair verband'. Wanneer twee features een lineaire relatie hebben verwacht je dat een verandering in één variabele een constante verandering in de andere variabele veroorzaakt.
Data drift: ook wel 'Model Drift'. Een model wordt getraind op een set op een bepaald moment, maar de echte wereld veranderd constant; een populatie kan bijvoorbeeld ouder worden, een ecosysteem kan opwarmen, medische techniek kan verbeteren. De mate waarin een model niet meer past op de nieuwe werkelijkheid (dus de mate waarin de test set en de huidige set van elkaar afwijken) noemen we data drift
(Data)set: Grote hoeveelheid gegevenspunten, denk bijvoorbeeld aan een excel sheet. vaak bij het trainen van een machine learning model wordt een set opgedeeld in een test- en een train set
- traintset: set met data waarop het getraind wordt
- testset: set die gebruikt wordt om het model te valideren
(Decision)tree: Ook wel 'keuzeboom'. Het is als het ware een soort flowchart (zie 'split') die wordt afgelopen om tot een predictie te komen. XGBoost bestaat uit meerdere decisiontrees die samen tot een voorspelling komen.
Feature: Een kenmerk of eigenschap, aan de hand waarvan wij een voorspelling kunnen maken. Als wij bijvoorbeeld de prijs van een huis willen voorspellen dan kan het woonoppervlak een feature daarvoor zijn.
Kolom: Verticale opeenvolging van informatie-cellen in tabel.
Leave: Als het ware een 'eindpunt' van een decision tree; de keuze waar die op land.
Machine learning: Subcategorie van artificial intelligence. Waar artificial intelligence ook een 'dom' algoritme kan zijn zoals bijvoorbeeld een NPC in een spel, duidt machine learning specifiek op een algoritme wat kan leren.
Mean: Engels voor gemiddelde; het totaal van gegevenspunten gedeeld door het aantaantal gegevenspunten.
Missing: Ontbrekende waarden worden zo aangegeven. In een set staat het ook vaak aangegeven als 'NaN' of 'None'
Model: Een algoritme dat de echte wereld probeert na te bootsen, als het ware probeert te modelleren. In ons geval hebben we een XGBoost model getrains op een hele hoop patiënten met en zonder SAD, zodat als wij een niewe patiënt aandragen het algoritme modelleert of de nieuwe patient SAD heeft.
N: Wordt vaak gebruikt om een discrete grootheid aan te geven. Bijvoorbeeld in een onderzoek waarbij 500 meetpunten genomen zijn kan er 'n = 500' staan.
Ordinaal: Wanneer een feature ordinaal is, kunnen we diens waarden sorteren naar grootte. Bijvoorbeeld een feature 'leeftijd' is ordinaal want we kunnen stellen de leeftijd 45 'groter' is dan de leeftijd 32.
Proxy: Wanneer een feature een proxy is voor een ander, dan is die feature een indirecte indicatie van de ander. Bijvoorbeeld een feature 'burgerlijke staat' zou een proxy kunnen zijn voor een feature 'leeftijd', iemands burgerlijke staat is immers een indicatie van iemands leeftijdscategorie.
Record: Hier mee wordt meestal een rij in een dataset bedoeld.
Rij: Horizontale opeenvolging van informatie-cellen in tabel.
Sample: In dit document wordt hier meestal een subset van de dataset bedoeld.
Sepsis associated delirium: Vaak afgekort als 'SAD', de aandoening die ons model probeert te voorspellen. Soms in dit notebook wordt 'SAD' gebruikt als indicatie dat iemand SAD positief is en 'NON-SAD' als de voorspelling negatief is.
SHAP waarde: Mate waarin een bepaalde feature bijdraagt aan een bepaalde voorspelling. SHAP is kort voor SHapley Additive exPlanations
Split: Binnen een decision tree maak je keuzes aan de hand van een waarde van een feature. Stel je voor we maken een flowchart om het weer te bepalen, je zou dan een keuzemoment hebben waarop je bepaald of het regent of niet; de 'feature' regen kan de 'waarde' wel of niet aannemen. Deze keuze binnen een keuzeboom noemen we een split.
Standaarddeviatie: Vaak afgekort als 'SD', een maat voor de spreiding van gegevenspunten in een dataset rondom het gemiddelde. Een hoge standaarddeviatie indiceert een hoge spreiding van de data.
Verwachtingswaarde: Ook wel gewogen gemiddelde. De waarde die een datapunt gemiddeld aanneemt. Bijvoorbeeld als we een normale dobbelsteen gooien is het gemiddelde van de ogen 3.5 dus de verwachtingswaarde is ook 3.5. Als we een gewogen dobbelsteen gooien waarbij de 6 twee keer zo vaak voorkomt, is het gemiddelde van de ogen nog steeds 3.5, maar de gemiddelde waarde ligt hoger dus de verwachtingswaarde is (1+2+3+4+5+2*6)/7 = 3.86.
X: Invoerwaarde van een functie. Wanneer het gaat om modellen die grote hoeveelheden data tegelijk verwerken kan 'X' dus ook staan voor een hele dataset die in een keer in het model gestopt wordt
XGBoost: Een machine learning model die meerdere decision trees genereert. Aan de hand van de decision trees komt het model tot diens predictie.
y: Uitvoerwaarde van een functie. In ons geval is dat dus de predictie, dus of iemand wel of geen SAD heeft.
DATA + MODELLEN ¶
data_raw = pd.read_stata('MIMIC-SAD_dta_files/MIMIC-IV.dta')
data_cf = data_raw.drop(['deliriumtime', 'hosp_mort', 'icu28dmort', 'stay_id', 'icustay', 'hospstay', 'sepsistime'], axis=1).dropna()
dummies = pd.get_dummies(data_cf['race'])
data = data_cf.drop('race',axis=1).join(dummies)
dummies = pd.get_dummies(data['first_careunit'])
data = data.drop('first_careunit',axis=1).join(dummies)
xgb_matrix = xgb.DMatrix(data.drop(['sad'], axis=1))
data_cf = data_cf.drop(['sad'], axis=1)
data_table = data_raw.drop('stay_id', axis=1)
data_table['gender'] = data_table['gender'].replace(to_replace = 0.0, value = 'FEMALE')
data_table['gender'] = data_table['gender'].replace(to_replace = 1.0, value = 'MALE')
data_table['sad'] = data_table['sad'].replace(to_replace = 0.0, value = 'NON-SAD')
data_table['sad'] = data_table['sad'].replace(to_replace = 1.0, value = 'SAD')
for i in ['vent', 'crrt', 'vaso', 'seda', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'aki', 'stroke','hosp_mort']:
data_table[i] = data_table[i].replace(to_replace = 0.0, value = 'FALSE')
data_table[i] = data_table[i].replace(to_replace = 1.0, value = 'TRUE')
data_raw: data direct uit het .dta bestand van de SAD repo
data_raw
| stay_id | age | weight | gender | race | first_careunit | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | sad | aki | stroke | hosp_mort | icustay | hospstay | deliriumtime | sepsistime | icu28dmort | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 30000646 | 44.0 | 79.000000 | 0 | AISAN | CCU | 37.000000 | 100.0 | 28.0 | 98.0 | 107.0 | 66.0 | 75.0 | 8.5 | 12.9 | 268.0 | 12.0 | 0.900000 | 102.0 | 138.0 | 105.0 | 3.5 | 2.2 | 7.8 | 3.4 | 1.3 | 14.500000 | 37.400002 | 25.0 | 12.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 113.0 | 200.0 | NaN | 10.0 | 0.0 |
| 1 | 30001446 | 56.0 | 119.300003 | 0 | WHITE | MICU | 36.720001 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.700000 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.400000 | 38.400002 | 15.0 | 14.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 40.0 | 147.0 | NaN | 1.0 | 0.0 |
| 2 | 30002415 | 73.0 | 83.500000 | 1 | WHITE | CVICU | 36.439999 | 71.0 | 16.0 | 100.0 | 117.0 | 67.0 | 87.0 | 6.7 | 10.4 | 96.0 | 9.0 | 0.600000 | 170.0 | 136.0 | 111.0 | 4.5 | 3.2 | NaN | NaN | 1.8 | 20.000000 | 37.700001 | 21.0 | 6.0 | 15.0 | 0.0 | 0.0 | 1.0 | 1.0 | 4 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 25.0 | 137.0 | NaN | 3.0 | 0.0 |
| 3 | 30003226 | 67.0 | 93.449997 | 0 | BLACK | SICU | 37.220001 | 89.0 | 17.0 | 98.0 | 111.0 | 63.0 | 71.0 | 8.3 | 7.3 | 225.0 | 63.0 | 18.200001 | 117.0 | 135.0 | 93.0 | 6.8 | 1.9 | 8.6 | 6.2 | NaN | NaN | NaN | 24.0 | 25.0 | 15.0 | 0.0 | 1.0 | 0.0 | 0.0 | 4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 47.0 | 322.0 | NaN | 40.0 | 0.0 |
| 4 | 30004242 | 76.0 | 77.599998 | 1 | BLACK | TSICU | 36.720001 | 59.0 | 21.0 | 97.0 | 107.0 | 90.0 | 94.0 | 9.4 | 11.0 | 280.0 | 10.0 | 0.500000 | 123.0 | 136.0 | 100.0 | 3.3 | 1.5 | 9.1 | 3.6 | 1.0 | 11.300000 | 24.900000 | 24.0 | 15.0 | 15.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 43.0 | 182.0 | NaN | 15.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 14615 | 39993425 | 93.0 | 47.799999 | 1 | WHITE | MICU | 35.830002 | 115.0 | 16.0 | 99.0 | 97.0 | 71.0 | 80.0 | 6.5 | 11.6 | 119.0 | 45.0 | 0.900000 | 121.0 | 152.0 | 119.0 | 3.6 | 2.3 | 8.1 | 2.8 | 1.7 | 18.299999 | 29.500000 | 19.0 | 13.0 | 8.0 | 0.0 | 0.0 | 1.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 63.0 | 84.0 | 2.0 | 2.0 | 0.0 |
| 14616 | 39993476 | 67.0 | 93.000000 | 0 | WHITE | CVICU | 36.439999 | 81.0 | 16.0 | 100.0 | 112.0 | 60.0 | 78.0 | 13.1 | 13.6 | 166.0 | 12.0 | 0.700000 | 113.0 | 135.0 | 105.0 | 4.5 | 2.3 | 8.2 | 2.5 | 1.2 | 13.300000 | 23.400000 | 24.0 | 10.0 | 15.0 | 1.0 | 0.0 | 0.0 | 1.0 | 2 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 24.0 | 100.0 | NaN | 13.0 | 0.0 |
| 14617 | 39993968 | 91.0 | 57.500000 | 1 | WHITE | CCU | 35.830002 | 43.0 | 17.0 | 100.0 | 78.0 | 39.0 | 49.0 | 15.8 | 14.7 | 258.0 | 33.0 | 1.500000 | 177.0 | 132.0 | 98.0 | 5.3 | 2.0 | 8.8 | 5.1 | 1.0 | 12.400000 | 26.400000 | 23.0 | 16.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 142.0 | 143.0 | 25.0 | 3.0 | 0.0 |
| 14618 | 39996044 | 59.0 | 66.400002 | 0 | WHITE | MICU | 36.389999 | 105.0 | 23.0 | 100.0 | 107.0 | 63.0 | 80.0 | 3.7 | 8.0 | 15.0 | 20.0 | 0.500000 | 161.0 | 139.0 | 105.0 | 4.0 | 1.9 | 7.6 | 4.5 | 1.3 | 13.900000 | 26.100000 | 26.0 | 12.0 | 15.0 | 1.0 | 0.0 | 1.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 169.0 | 741.0 | 15.0 | 3.0 | 0.0 |
| 14619 | 39999301 | 78.0 | 107.699997 | 0 | BLACK | CVICU | 36.610001 | 58.0 | 15.0 | 96.0 | 108.0 | 62.0 | 73.0 | 9.3 | 12.5 | 197.0 | 17.0 | 1.500000 | 114.0 | 142.0 | 109.0 | 3.4 | 2.1 | 8.8 | 3.4 | 1.1 | 13.400000 | 26.500000 | 24.0 | 12.0 | 15.0 | 0.0 | 0.0 | 0.0 | 1.0 | 2 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 28.0 | 25.0 | NaN | 24.0 | 1.0 |
14620 rows × 50 columns
data_cf: data voor counterfactual modellen (update: deprecated, counterfactuals debruikt nu data) (update 2: deze wordt nu wel gebruikt voor outlier detection)
- geen one-hot encoding
- NaN rijen gedropt
- ongebruikte features gedropt ('hosp_mort' etc.)
- target ('sad') gedropt
data_cf
| age | weight | gender | race | first_careunit | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | aki | stroke | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44.0 | 79.000000 | 0 | AISAN | CCU | 37.000000 | 100.0 | 28.0 | 98.0 | 107.0 | 66.0 | 75.0 | 8.500000 | 12.9 | 268.0 | 12.0 | 0.9 | 102.0 | 138.0 | 105.0 | 3.5 | 2.2 | 7.8 | 3.4 | 1.3 | 14.500000 | 37.400002 | 25.0 | 12.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | 56.0 | 119.300003 | 0 | WHITE | MICU | 36.720001 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.000000 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.400000 | 38.400002 | 15.0 | 14.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 4 | 76.0 | 77.599998 | 1 | BLACK | TSICU | 36.720001 | 59.0 | 21.0 | 97.0 | 107.0 | 90.0 | 94.0 | 9.400000 | 11.0 | 280.0 | 10.0 | 0.5 | 123.0 | 136.0 | 100.0 | 3.3 | 1.5 | 9.1 | 3.6 | 1.0 | 11.300000 | 24.900000 | 24.0 | 15.0 | 15.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
| 5 | 83.0 | 72.000000 | 0 | WHITE | SICU | 36.330002 | 109.0 | 16.0 | 100.0 | 111.0 | 63.0 | 79.0 | 4.800000 | 13.3 | 307.0 | 62.0 | 2.8 | 108.0 | 136.0 | 108.0 | 3.6 | 2.1 | 6.4 | 4.1 | 1.4 | 16.200001 | 26.900000 | 18.0 | 14.0 | 15.0 | 1.0 | 0.0 | 1.0 | 1.0 | 3 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 |
| 6 | 57.0 | 77.500000 | 0 | WHITE | MICU | 38.669998 | 101.0 | 23.0 | 99.0 | 130.0 | 84.0 | 93.0 | 17.200001 | 15.1 | 261.0 | 25.0 | 1.0 | 100.0 | 138.0 | 105.0 | 4.3 | 2.0 | 8.5 | 4.0 | 1.2 | 13.500000 | 33.799999 | 21.0 | 16.0 | 13.0 | 1.0 | 0.0 | 0.0 | 1.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 14615 | 93.0 | 47.799999 | 1 | WHITE | MICU | 35.830002 | 115.0 | 16.0 | 99.0 | 97.0 | 71.0 | 80.0 | 6.500000 | 11.6 | 119.0 | 45.0 | 0.9 | 121.0 | 152.0 | 119.0 | 3.6 | 2.3 | 8.1 | 2.8 | 1.7 | 18.299999 | 29.500000 | 19.0 | 13.0 | 8.0 | 0.0 | 0.0 | 1.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 14616 | 67.0 | 93.000000 | 0 | WHITE | CVICU | 36.439999 | 81.0 | 16.0 | 100.0 | 112.0 | 60.0 | 78.0 | 13.100000 | 13.6 | 166.0 | 12.0 | 0.7 | 113.0 | 135.0 | 105.0 | 4.5 | 2.3 | 8.2 | 2.5 | 1.2 | 13.300000 | 23.400000 | 24.0 | 10.0 | 15.0 | 1.0 | 0.0 | 0.0 | 1.0 | 2 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 |
| 14617 | 91.0 | 57.500000 | 1 | WHITE | CCU | 35.830002 | 43.0 | 17.0 | 100.0 | 78.0 | 39.0 | 49.0 | 15.800000 | 14.7 | 258.0 | 33.0 | 1.5 | 177.0 | 132.0 | 98.0 | 5.3 | 2.0 | 8.8 | 5.1 | 1.0 | 12.400000 | 26.400000 | 23.0 | 16.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
| 14618 | 59.0 | 66.400002 | 0 | WHITE | MICU | 36.389999 | 105.0 | 23.0 | 100.0 | 107.0 | 63.0 | 80.0 | 3.700000 | 8.0 | 15.0 | 20.0 | 0.5 | 161.0 | 139.0 | 105.0 | 4.0 | 1.9 | 7.6 | 4.5 | 1.3 | 13.900000 | 26.100000 | 26.0 | 12.0 | 15.0 | 1.0 | 0.0 | 1.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 14619 | 78.0 | 107.699997 | 0 | BLACK | CVICU | 36.610001 | 58.0 | 15.0 | 96.0 | 108.0 | 62.0 | 73.0 | 9.300000 | 12.5 | 197.0 | 17.0 | 1.5 | 114.0 | 142.0 | 109.0 | 3.4 | 2.1 | 8.8 | 3.4 | 1.1 | 13.400000 | 26.500000 | 24.0 | 12.0 | 15.0 | 0.0 | 0.0 | 0.0 | 1.0 | 2 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
11196 rows × 42 columns
data_table: data voor TableOne en de correlatie matrix: een aantal features zijn aangepast om het aangenamer lezen te maken (bijv 1.0 vervangen voor 'TRUE', als het een boolean kolom is)
- geen one-hot encoding
- geen NaN rijen gedropt
- geen ongebruikte features gedropt, behalve 'stay_id': dit is een index; data-analyse hierop doen zou onzinnig zijn
- geen target gedropt
data_table
| age | weight | gender | race | first_careunit | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | sad | aki | stroke | hosp_mort | icustay | hospstay | deliriumtime | sepsistime | icu28dmort | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44.0 | 79.000000 | FEMALE | AISAN | CCU | 37.000000 | 100.0 | 28.0 | 98.0 | 107.0 | 66.0 | 75.0 | 8.5 | 12.9 | 268.0 | 12.0 | 0.900000 | 102.0 | 138.0 | 105.0 | 3.5 | 2.2 | 7.8 | 3.4 | 1.3 | 14.500000 | 37.400002 | 25.0 | 12.0 | 15.0 | FALSE | FALSE | TRUE | FALSE | 3 | FALSE | FALSE | FALSE | FALSE | FALSE | NON-SAD | FALSE | FALSE | TRUE | 113.0 | 200.0 | NaN | 10.0 | 0.0 |
| 1 | 56.0 | 119.300003 | FEMALE | WHITE | MICU | 36.720001 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.700000 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.400000 | 38.400002 | 15.0 | 14.0 | 15.0 | FALSE | FALSE | TRUE | FALSE | 8 | FALSE | FALSE | FALSE | FALSE | FALSE | NON-SAD | TRUE | FALSE | FALSE | 40.0 | 147.0 | NaN | 1.0 | 0.0 |
| 2 | 73.0 | 83.500000 | MALE | WHITE | CVICU | 36.439999 | 71.0 | 16.0 | 100.0 | 117.0 | 67.0 | 87.0 | 6.7 | 10.4 | 96.0 | 9.0 | 0.600000 | 170.0 | 136.0 | 111.0 | 4.5 | 3.2 | NaN | NaN | 1.8 | 20.000000 | 37.700001 | 21.0 | 6.0 | 15.0 | FALSE | FALSE | TRUE | TRUE | 4 | FALSE | FALSE | FALSE | TRUE | FALSE | NON-SAD | TRUE | FALSE | FALSE | 25.0 | 137.0 | NaN | 3.0 | 0.0 |
| 3 | 67.0 | 93.449997 | FEMALE | BLACK | SICU | 37.220001 | 89.0 | 17.0 | 98.0 | 111.0 | 63.0 | 71.0 | 8.3 | 7.3 | 225.0 | 63.0 | 18.200001 | 117.0 | 135.0 | 93.0 | 6.8 | 1.9 | 8.6 | 6.2 | NaN | NaN | NaN | 24.0 | 25.0 | 15.0 | FALSE | TRUE | FALSE | FALSE | 4 | FALSE | FALSE | FALSE | FALSE | FALSE | NON-SAD | TRUE | FALSE | FALSE | 47.0 | 322.0 | NaN | 40.0 | 0.0 |
| 4 | 76.0 | 77.599998 | MALE | BLACK | TSICU | 36.720001 | 59.0 | 21.0 | 97.0 | 107.0 | 90.0 | 94.0 | 9.4 | 11.0 | 280.0 | 10.0 | 0.500000 | 123.0 | 136.0 | 100.0 | 3.3 | 1.5 | 9.1 | 3.6 | 1.0 | 11.300000 | 24.900000 | 24.0 | 15.0 | 15.0 | FALSE | FALSE | FALSE | FALSE | 3 | FALSE | FALSE | FALSE | TRUE | FALSE | NON-SAD | FALSE | FALSE | FALSE | 43.0 | 182.0 | NaN | 15.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 14615 | 93.0 | 47.799999 | MALE | WHITE | MICU | 35.830002 | 115.0 | 16.0 | 99.0 | 97.0 | 71.0 | 80.0 | 6.5 | 11.6 | 119.0 | 45.0 | 0.900000 | 121.0 | 152.0 | 119.0 | 3.6 | 2.3 | 8.1 | 2.8 | 1.7 | 18.299999 | 29.500000 | 19.0 | 13.0 | 8.0 | FALSE | FALSE | TRUE | FALSE | 3 | FALSE | FALSE | FALSE | FALSE | FALSE | SAD | FALSE | FALSE | TRUE | 63.0 | 84.0 | 2.0 | 2.0 | 0.0 |
| 14616 | 67.0 | 93.000000 | FEMALE | WHITE | CVICU | 36.439999 | 81.0 | 16.0 | 100.0 | 112.0 | 60.0 | 78.0 | 13.1 | 13.6 | 166.0 | 12.0 | 0.700000 | 113.0 | 135.0 | 105.0 | 4.5 | 2.3 | 8.2 | 2.5 | 1.2 | 13.300000 | 23.400000 | 24.0 | 10.0 | 15.0 | TRUE | FALSE | FALSE | TRUE | 2 | FALSE | FALSE | FALSE | TRUE | FALSE | NON-SAD | TRUE | FALSE | FALSE | 24.0 | 100.0 | NaN | 13.0 | 0.0 |
| 14617 | 91.0 | 57.500000 | MALE | WHITE | CCU | 35.830002 | 43.0 | 17.0 | 100.0 | 78.0 | 39.0 | 49.0 | 15.8 | 14.7 | 258.0 | 33.0 | 1.500000 | 177.0 | 132.0 | 98.0 | 5.3 | 2.0 | 8.8 | 5.1 | 1.0 | 12.400000 | 26.400000 | 23.0 | 16.0 | 15.0 | FALSE | FALSE | TRUE | FALSE | 4 | FALSE | FALSE | FALSE | FALSE | FALSE | SAD | TRUE | FALSE | FALSE | 142.0 | 143.0 | 25.0 | 3.0 | 0.0 |
| 14618 | 59.0 | 66.400002 | FEMALE | WHITE | MICU | 36.389999 | 105.0 | 23.0 | 100.0 | 107.0 | 63.0 | 80.0 | 3.7 | 8.0 | 15.0 | 20.0 | 0.500000 | 161.0 | 139.0 | 105.0 | 4.0 | 1.9 | 7.6 | 4.5 | 1.3 | 13.900000 | 26.100000 | 26.0 | 12.0 | 15.0 | TRUE | FALSE | TRUE | FALSE | 3 | FALSE | FALSE | FALSE | FALSE | FALSE | SAD | FALSE | FALSE | FALSE | 169.0 | 741.0 | 15.0 | 3.0 | 0.0 |
| 14619 | 78.0 | 107.699997 | FEMALE | BLACK | CVICU | 36.610001 | 58.0 | 15.0 | 96.0 | 108.0 | 62.0 | 73.0 | 9.3 | 12.5 | 197.0 | 17.0 | 1.500000 | 114.0 | 142.0 | 109.0 | 3.4 | 2.1 | 8.8 | 3.4 | 1.1 | 13.400000 | 26.500000 | 24.0 | 12.0 | 15.0 | FALSE | FALSE | FALSE | TRUE | 2 | FALSE | TRUE | FALSE | FALSE | FALSE | NON-SAD | TRUE | FALSE | TRUE | 28.0 | 25.0 | NaN | 24.0 | 1.0 |
14620 rows × 49 columns
data: data voor alle andere explanations
- one-hot encoding
- NaN rijen gedropt
- ongebruikte features gedropt
- geen target gedropt
data # data voor alle andere explanations: wel one-hot, wel NaN gedropt, wel ongebruikte features gedropt, geen target gedropt
| age | weight | gender | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | sad | aki | stroke | AISAN | BLACK | HISPANIC | OTHER | WHITE | unknown | CCU | CVICU | MICU | MICU/SICU | NICU | SICU | TSICU | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44.0 | 79.000000 | 0 | 37.000000 | 100.0 | 28.0 | 98.0 | 107.0 | 66.0 | 75.0 | 8.500000 | 12.9 | 268.0 | 12.0 | 0.9 | 102.0 | 138.0 | 105.0 | 3.5 | 2.2 | 7.8 | 3.4 | 1.3 | 14.500000 | 37.400002 | 25.0 | 12.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 56.0 | 119.300003 | 0 | 36.720001 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.000000 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.400000 | 38.400002 | 15.0 | 14.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| 4 | 76.0 | 77.599998 | 1 | 36.720001 | 59.0 | 21.0 | 97.0 | 107.0 | 90.0 | 94.0 | 9.400000 | 11.0 | 280.0 | 10.0 | 0.5 | 123.0 | 136.0 | 100.0 | 3.3 | 1.5 | 9.1 | 3.6 | 1.0 | 11.300000 | 24.900000 | 24.0 | 15.0 | 15.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
| 5 | 83.0 | 72.000000 | 0 | 36.330002 | 109.0 | 16.0 | 100.0 | 111.0 | 63.0 | 79.0 | 4.800000 | 13.3 | 307.0 | 62.0 | 2.8 | 108.0 | 136.0 | 108.0 | 3.6 | 2.1 | 6.4 | 4.1 | 1.4 | 16.200001 | 26.900000 | 18.0 | 14.0 | 15.0 | 1.0 | 0.0 | 1.0 | 1.0 | 3 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
| 6 | 57.0 | 77.500000 | 0 | 38.669998 | 101.0 | 23.0 | 99.0 | 130.0 | 84.0 | 93.0 | 17.200001 | 15.1 | 261.0 | 25.0 | 1.0 | 100.0 | 138.0 | 105.0 | 4.3 | 2.0 | 8.5 | 4.0 | 1.2 | 13.500000 | 33.799999 | 21.0 | 16.0 | 13.0 | 1.0 | 0.0 | 0.0 | 1.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 14615 | 93.0 | 47.799999 | 1 | 35.830002 | 115.0 | 16.0 | 99.0 | 97.0 | 71.0 | 80.0 | 6.500000 | 11.6 | 119.0 | 45.0 | 0.9 | 121.0 | 152.0 | 119.0 | 3.6 | 2.3 | 8.1 | 2.8 | 1.7 | 18.299999 | 29.500000 | 19.0 | 13.0 | 8.0 | 0.0 | 0.0 | 1.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| 14616 | 67.0 | 93.000000 | 0 | 36.439999 | 81.0 | 16.0 | 100.0 | 112.0 | 60.0 | 78.0 | 13.100000 | 13.6 | 166.0 | 12.0 | 0.7 | 113.0 | 135.0 | 105.0 | 4.5 | 2.3 | 8.2 | 2.5 | 1.2 | 13.300000 | 23.400000 | 24.0 | 10.0 | 15.0 | 1.0 | 0.0 | 0.0 | 1.0 | 2 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
| 14617 | 91.0 | 57.500000 | 1 | 35.830002 | 43.0 | 17.0 | 100.0 | 78.0 | 39.0 | 49.0 | 15.800000 | 14.7 | 258.0 | 33.0 | 1.5 | 177.0 | 132.0 | 98.0 | 5.3 | 2.0 | 8.8 | 5.1 | 1.0 | 12.400000 | 26.400000 | 23.0 | 16.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
| 14618 | 59.0 | 66.400002 | 0 | 36.389999 | 105.0 | 23.0 | 100.0 | 107.0 | 63.0 | 80.0 | 3.700000 | 8.0 | 15.0 | 20.0 | 0.5 | 161.0 | 139.0 | 105.0 | 4.0 | 1.9 | 7.6 | 4.5 | 1.3 | 13.900000 | 26.100000 | 26.0 | 12.0 | 15.0 | 1.0 | 0.0 | 1.0 | 0.0 | 3 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
| 14619 | 78.0 | 107.699997 | 0 | 36.610001 | 58.0 | 15.0 | 96.0 | 108.0 | 62.0 | 73.0 | 9.300000 | 12.5 | 197.0 | 17.0 | 1.5 | 114.0 | 142.0 | 109.0 | 3.4 | 2.1 | 8.8 | 3.4 | 1.1 | 13.400000 | 26.500000 | 24.0 | 12.0 | 15.0 | 0.0 | 0.0 | 0.0 | 1.0 | 2 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
11196 rows × 54 columns
model = pickle.load(open("xgb.pkl", "rb"))
cf_random = pickle.load(open("cf_random.pkl", "rb"))
cf_genetic = pickle.load(open("cf_genetic.pkl", "rb"))
cf_kdtree = pickle.load(open("cf_kdtree.pkl", "rb"))
explainer = shap.TreeExplainer(model)
explainer_waterfall = shap.Explainer(model, data) # het lijkt er op dat we beter de waterfall plot niet kunnen gebruiken, in dat geval zou deze dus ook niet nodig zijn
shap_values = explainer.shap_values(data.loc[:, ~data.columns.isin(["sad"])])
shap_waterfall = explainer_waterfall(data)
94%|=================== | 10569/11196 [00:12<00:00]
GLOBAL EXPLANATIONS ¶
Sommige explanations in deze categorie (confusion matrix, Table1) kunnen meerdere keren in het dashboard: een keer voor de oorspronkelijke (test)set, en een keer voor de set die op dat moment aan het dashboard gekoppeld is.
TableOne ¶
Gebaseerd op de Table1 package van de programmeer taal R. Deze tabel geeft een algemeen overzicht van de beschikbare data.
Hoe je dit afleest: In de linker kolom staat om welke variabele het gaat. Hierbinnen heb je twee soorten variabelen:
- ordinale variabelen: De eerste waarde in de drie laatste kolommen geeft het gemiddelde aan binnen de categorie van die kolom. De tweede waarde, die tussen haakjes, geeft de standaarddeviatie (SD) weer. Bijvoorbeeld 'age' heeft over de gehele dataset een gemiddelde van 66.9 en een SD van 15.9. bij mensen die SAD hebben is de gemiddelde 'age' 67.3 met een SD van 16.1.
- categoriale variabelen: De tweede kolom specificeert welke waarde deze variabele aanneemt. De linker waarde geeft aan hoe vaak deze waarde voorkomt binnen de respectievelijke kolom, de rechter waarde laat zien welk percentage dit is van de volledige set. Bijvoorbeeld 'gender' neemt binnen de hele set 8518 de waarde 'FEMALE' aan, dit is 58.3% van de hele set. Hier tegenover staat dat er 6102 'MALE' zijn, dus 41.7%. Binnen de SAD patienten is 57.6% 'MALE' en 42.4% 'FEMALE'
De kolom 'missing' geeft aan hoe vaak de variabele geen waarde aanneemt binnen de set. De rij 'n' is niet een variabele, dit gaat over het totaal aantal meetpunten binnen de categorie van de kolom.
De rijen 'deliriumtime', 'hosp_mort', 'icu28dmort', 'icustay', 'hospstay' en 'sepsistime' zijn geen features in het model, deze informatie is immers niet aanwezig op het moment van voorspellen; ze staan in de tabel omdat dit wel handige informatie is.
categorical = ['gender', 'race', 'first_careunit', 'vent', 'crrt', 'vaso', 'seda', 'ami', 'ckd', 'copd', 'hyperte', 'dm', 'sad', 'aki', 'stroke','hosp_mort']
groupby = ['sad']
table1 = TableOne(data_table, categorical=categorical, groupby=groupby, pval=False)
print(table1.tabulate(tablefmt = "fancy_grid")) # je kan "fancy_grid" vervangen voor "html" als dit makkelijker is voor de dashboard
╒═════════════════════════╤═══════════╤═══════════╤═══════════════╤═══════════════╤═══════════════╕ │ │ │ Missing │ Overall │ NON-SAD │ SAD │ ╞═════════════════════════╪═══════════╪═══════════╪═══════════════╪═══════════════╪═══════════════╡ │ n │ │ │ 14620 │ 9230 │ 5390 │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ age, mean (SD) │ │ 0 │ 66.9 (15.9) │ 66.7 (15.8) │ 67.3 (16.1) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ weight, mean (SD) │ │ 160 │ 83.1 (23.6) │ 83.1 (23.0) │ 83.2 (24.6) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ gender, n (%) │ FEMALE │ 0 │ 8518 (58.3) │ 5416 (58.7) │ 3102 (57.6) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ MALE │ │ 6102 (41.7) │ 3814 (41.3) │ 2288 (42.4) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ race, n (%) │ AISAN │ 0 │ 426 (2.9) │ 311 (3.4) │ 115 (2.1) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ BLACK │ │ 1266 (8.7) │ 766 (8.3) │ 500 (9.3) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ HISPANIC │ │ 557 (3.8) │ 360 (3.9) │ 197 (3.7) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ OTHER │ │ 642 (4.4) │ 417 (4.5) │ 225 (4.2) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ WHITE │ │ 9723 (66.5) │ 6372 (69.0) │ 3351 (62.2) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ unknown │ │ 2006 (13.7) │ 1004 (10.9) │ 1002 (18.6) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ first_careunit, n (%) │ CCU │ 0 │ 1366 (9.3) │ 881 (9.5) │ 485 (9.0) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ CVICU │ │ 3461 (23.7) │ 2772 (30.0) │ 689 (12.8) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ MICU │ │ 3078 (21.1) │ 1601 (17.3) │ 1477 (27.4) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ MICU/SICU │ │ 2706 (18.5) │ 1780 (19.3) │ 926 (17.2) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ NICU │ │ 534 (3.7) │ 234 (2.5) │ 300 (5.6) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ SICU │ │ 1887 (12.9) │ 1112 (12.0) │ 775 (14.4) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TSICU │ │ 1588 (10.9) │ 850 (9.2) │ 738 (13.7) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ temperature, mean (SD) │ │ 48 │ 36.7 (0.8) │ 36.7 (0.8) │ 36.8 (0.9) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ heart_rate, mean (SD) │ │ 1 │ 89.7 (20.3) │ 88.2 (19.6) │ 92.3 (21.1) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ resp_rate, mean (SD) │ │ 24 │ 19.6 (6.0) │ 19.1 (5.9) │ 20.6 (6.1) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ spo2, mean (SD) │ │ 5 │ 97.1 (4.0) │ 97.3 (3.7) │ 96.7 (4.3) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ sbp, mean (SD) │ │ 6 │ 120.3 (23.9) │ 119.6 (23.1) │ 121.4 (25.1) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ dbp, mean (SD) │ │ 15 │ 66.5 (17.7) │ 65.7 (16.8) │ 67.9 (19.0) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ mbp, mean (SD) │ │ 14 │ 81.5 (17.8) │ 81.0 (17.0) │ 82.5 (19.0) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ wbc, mean (SD) │ │ 115 │ 13.1 (8.1) │ 12.8 (7.8) │ 13.7 (8.4) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ hemoglobin, mean (SD) │ │ 96 │ 10.3 (2.2) │ 10.3 (2.1) │ 10.5 (2.3) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ platelet, mean (SD) │ │ 102 │ 191.5 (106.0) │ 190.7 (105.3) │ 192.9 (107.1) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ bun, mean (SD) │ │ 60 │ 28.2 (22.9) │ 26.1 (21.0) │ 31.7 (25.6) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ cr, mean (SD) │ │ 56 │ 1.5 (1.5) │ 1.4 (1.5) │ 1.6 (1.6) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ glu, mean (SD) │ │ 66 │ 150.2 (74.4) │ 144.9 (66.5) │ 159.3 (85.5) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ Na, mean (SD) │ │ 50 │ 137.4 (5.5) │ 136.9 (5.0) │ 138.2 (6.2) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ Cl, mean (SD) │ │ 51 │ 103.8 (6.7) │ 103.8 (6.3) │ 103.9 (7.3) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ K, mean (SD) │ │ 59 │ 4.3 (0.8) │ 4.3 (0.8) │ 4.3 (0.9) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ Mg, mean (SD) │ │ 608 │ 2.0 (0.5) │ 2.0 (0.5) │ 2.0 (0.5) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ Ca, mean (SD) │ │ 1399 │ 8.2 (0.9) │ 8.2 (0.8) │ 8.2 (0.9) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ P, mean (SD) │ │ 1341 │ 3.8 (1.5) │ 3.7 (1.3) │ 4.0 (1.7) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ inr, mean (SD) │ │ 1609 │ 1.5 (0.8) │ 1.5 (0.7) │ 1.6 (0.8) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ pt, mean (SD) │ │ 1578 │ 17.0 (9.8) │ 16.7 (8.8) │ 17.5 (11.3) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ ptt, mean (SD) │ │ 1658 │ 37.8 (22.4) │ 37.2 (21.4) │ 39.0 (24.0) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ bicarbonate, mean (SD) │ │ 57 │ 22.2 (4.6) │ 22.5 (4.3) │ 21.7 (5.1) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ aniongap, mean (SD) │ │ 64 │ 15.0 (4.6) │ 14.4 (4.2) │ 16.0 (4.9) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ gcs, mean (SD) │ │ 2 │ 14.2 (2.4) │ 14.3 (2.4) │ 14.1 (2.3) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ vent, n (%) │ FALSE │ 0 │ 8023 (54.9) │ 5974 (64.7) │ 2049 (38.0) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 6597 (45.1) │ 3256 (35.3) │ 3341 (62.0) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ crrt, n (%) │ FALSE │ 0 │ 14364 (98.2) │ 9152 (99.2) │ 5212 (96.7) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 256 (1.8) │ 78 (0.8) │ 178 (3.3) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ vaso, n (%) │ FALSE │ 0 │ 7482 (51.2) │ 4992 (54.1) │ 2490 (46.2) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 7138 (48.8) │ 4238 (45.9) │ 2900 (53.8) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ seda, n (%) │ FALSE │ 0 │ 7962 (54.5) │ 4968 (53.8) │ 2994 (55.5) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 6658 (45.5) │ 4262 (46.2) │ 2396 (44.5) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ sofa_score, mean (SD) │ │ 0 │ 3.6 (1.9) │ 3.4 (1.7) │ 3.9 (2.2) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ ami, n (%) │ FALSE │ 0 │ 12976 (88.8) │ 8293 (89.8) │ 4683 (86.9) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 1644 (11.2) │ 937 (10.2) │ 707 (13.1) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ ckd, n (%) │ FALSE │ 0 │ 11680 (79.9) │ 7417 (80.4) │ 4263 (79.1) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 2940 (20.1) │ 1813 (19.6) │ 1127 (20.9) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ copd, n (%) │ FALSE │ 0 │ 14088 (96.4) │ 8944 (96.9) │ 5144 (95.4) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 532 (3.6) │ 286 (3.1) │ 246 (4.6) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ hyperte, n (%) │ FALSE │ 0 │ 8322 (56.9) │ 5158 (55.9) │ 3164 (58.7) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 6298 (43.1) │ 4072 (44.1) │ 2226 (41.3) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ dm, n (%) │ FALSE │ 0 │ 11962 (81.8) │ 7477 (81.0) │ 4485 (83.2) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 2658 (18.2) │ 1753 (19.0) │ 905 (16.8) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ sad, n (%) │ NON-SAD │ 0 │ 9230 (63.1) │ 9230 (100.0) │ │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ SAD │ │ 5390 (36.9) │ │ 5390 (100.0) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ aki, n (%) │ FALSE │ 0 │ 6462 (44.2) │ 4541 (49.2) │ 1921 (35.6) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 8158 (55.8) │ 4689 (50.8) │ 3469 (64.4) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ stroke, n (%) │ FALSE │ 0 │ 13479 (92.2) │ 8777 (95.1) │ 4702 (87.2) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 1141 (7.8) │ 453 (4.9) │ 688 (12.8) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ hosp_mort, n (%) │ FALSE │ 0 │ 12768 (87.3) │ 8527 (92.4) │ 4241 (78.7) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ │ TRUE │ │ 1852 (12.7) │ 703 (7.6) │ 1149 (21.3) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ icustay, mean (SD) │ │ 0 │ 125.7 (147.7) │ 83.8 (92.8) │ 197.4 (190.4) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ hospstay, mean (SD) │ │ 0 │ 311.0 (307.2) │ 256.1 (247.8) │ 404.9 (369.9) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ deliriumtime, mean (SD) │ │ 9230 │ 44.9 (63.5) │ nan (nan) │ 44.9 (63.5) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ sepsistime, mean (SD) │ │ 0 │ 8.2 (20.6) │ 9.1 (23.0) │ 6.7 (15.5) │ ├─────────────────────────┼───────────┼───────────┼───────────────┼───────────────┼───────────────┤ │ icu28dmort, mean (SD) │ │ 0 │ 0.1 (0.3) │ 0.0 (0.2) │ 0.1 (0.3) │ ╘═════════════════════════╧═══════════╧═══════════╧═══════════════╧═══════════════╧═══════════════╛
Correlation Matrix ¶
Deze matrix geeft de mate van lineaire interactie tussen variabelen weer. Als de correlatie tussen twee variabelen te hoog is kan je je dus afvragen of de een een proxy is voor de ander.
corr = data_table.corr(method='pearson') #method kan aangepast worden naar 'kendall' of 'spearman', zou leuk zijn als dit interactief kan op de dashboard
mask = np.triu(np.ones_like(corr, dtype=bool))
f, ax = plt.subplots(figsize=(20, 16))
cmap = sns.color_palette("viridis_r", as_cmap=True)
sns.heatmap(
corr,
mask=mask,
cmap=cmap,
vmax=.3,
center=0,
square=True,
linewidths=.5,
cbar_kws={"shrink": .5}
)
<Axes: >
Confusion Matrix ¶
Deze matrix geeft aan hoe vaak het model een correcte/incorrecte voorspelling doet. Bijv. in dit geval kan je aflezen dat in de 3,845 gevallen van SAD, het model 1,065 keer een incorrecte voorspelling en 2,780 keer een correcte voorspelling gedaan heeft.
Als er een matrix van de train set en de huidige set te zien is kan het verschil geïnterpreteerd worden als een mate van data drift.
xgb_matrix_full = xgb.DMatrix(data.loc[:, ~data.columns.isin(["sad"])], label=data["sad"])
# het idee is dat deze er twee keer in staat: een keer met de oorspronkelijke testset en een keer met de set die op dat moment aan de app gekoppeld zit
xgb_pred_prob = model.predict(xgb_matrix_full) # dit is voor de volledige set. vervang `xgb_matrix_full` als je de performance van het model wrt een andere set wilt.
xgb_pred = np.where(xgb_pred_prob > 0.5, 1, 0)
xgb_pred_factor = pd.factorize(xgb_pred)[0]
test_sad_factor = pd.factorize(data["sad"])[0]
confusion_matrix = metrics.confusion_matrix(xgb_pred_factor, test_sad_factor)
print(confusion_matrix)
cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix = confusion_matrix, display_labels = [False, True])
cm_display.plot()
plt.show()
[[5656 1695] [1065 2780]]
XGB Functionality ¶
Feature importance ¶
De volgende grafieken duiden aan in hoe verre bepaalde features invloed hebben op het model, op basis van:
gain: de gemiddelde informatiewinst (relatieve entropie) van de splits waarin de respectievelijke feature voorkomt. Een 'split' kan gezien worden als een 'keuze' binnen de keuzeboom.weight: het totaal aantal splits waarin de respectievelijke feature voorkomt in alle trees.cover: gemiddelde van de 'coverage' van splits waarin de respectievelijke feature voorkomt, waarbij coverage is gedefinieerd als het aantal voorspellingen die beinvloed worden door de die split.
xgb.plot_importance(model, importance_type='gain', max_num_features=30)
plt.title('feature importance: gain')
plt.show()
xgb.plot_importance(model, max_num_features=30)
plt.title('feature importance: weight')
plt.show()
xgb.plot_importance(model, importance_type='cover', max_num_features=30)
plt.title('feature importance: cover')
plt.show()
Graphviz tree visualiser ¶
Een graaf visualisatie van één tree in het XGBoost model; het model bestaat uit een groot aantal van dit soort keuzebomen. In het voorbeeld staat een visualisatie van de meest representatieve boom, maar andere trees kunnen ook worden weergegeven.
xgb.to_graphviz(model, num_trees=model.best_iteration) # je kan ook andere trees visualiseren door `model.best_iteration` te vervangen voor een getal (int), zou top zijn als dit interactief kan op het dashboard
SHAP ¶
Uit de SHAP documentatie:
SHAP (SHapley Additive exPlanations) is a game theoretic approach to explain the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions.
Kort gezegd: bij elke voorspelling kunnen we een Shapley-waarde voor iedere feature bepalen, deze waarde is een indicatie voor de mate waarin de feature heeft bijgedragen aan deze individuele voorspelling.
wat kan ik met shap waarde
Interactive Plot ¶
Deze plot laat een sample uit de data (in het voorbeeld n = 500) zien over de twee gekozen assen. De plot lijkt aanvankelijk enigszins intimiderend, maar met wat intuïtie komt men er wel uit.
De dropdown aan de linkerkant van de grafiek bepaald de y-as, er kan gekozen worden voor:
- f(x): de cumulatieve SHAP waarde
- feature effects: de SHAP waarde van die specifieke feature
De dropdown boven de grafiek bepaald het groeperen en sorteren van de samples over de x-as. De eerste drie opties (similarity, output value, original) groeperen niet maar sorteren op de aangegeven volgorde. De feature specifieke opties groeperen de samples op elke unieke waarde van deze feature, en sorteren vervolgens deze groepen op volgorde van diens waarde in deze feature.
Rood in in de grafiek duidt op een positieve bijdrage aan de uitkomst en blauw een negatieve. Er kan ook op een plotpunt geklikt worden, dan geeft het de index van de betreffende patient, of de index van één van de patiënten binnen een groep wanneer voor de x-as een feature geselecteerd is.
Enkele voorbeelden:
- Als we de x-as op 'sofa_score' zetten en de y-as op 'f(x)' dan zien we dat, binnen deze sample, de kans het grootst is op SAD bij een sofa score van 10. Wanneer we bij de '10' hoveren met de muis zien we dat de groep met deze score 3 groot is. Ook zien we dan de gemiddelden van een aantal features binnen deze groep, en dat 'vent' het meest positief bijdraagt en 'gcs' het meest negatief.
- Als we de x-as op 'sample order by output value' en de y-as op 'vent effects' dan zien we dat patiënten met een een lagere outputwaarde (dus mensen die geen SAD hebben) minder vaak aan de beademing zitten.
# kan zeker in het data-analysten scherm
shap.force_plot(
explainer.expected_value,
shap_values[:500, :],
data.loc[:, ~data.columns.isin(["sad"])].iloc[:500, :],
plot_cmap=["#FDE725", "#440154"]
)
# `500` kan aangepast worden voor een groter/kleiner sample, let wel dat een te grote sample veel lag meebrengt
# er kan ook een lijst met indices worden meegegeven als we een specifiek groepje patiënten willen vergelijken
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
Summary plot - Bar¶
Deze plot is vergelijkbaar met de feature importance plots, maar dan dus op basis van de gemiddelde SHAP waarden van de feature
shap.summary_plot(shap_values, data.loc[:, ~data.columns.isin(["sad"])], plot_type="bar")
Summary plot - Beeswarm ¶
De 'beeswarms' (of, sina plots) geven de dichtheid van datapunten rond bepaalde SHAP waarden van de respectievelijke feature aan. De kleur geeft de waarde van de feature zelf aan. Je kan hieruit aflezen wat voor soort waarden de features aan moeten nemen voor een grote impact op het model, en ook hoe vaak dit voor komt. Bijvoorbeeld: bij de feature 'Na' zien we bij hoge SHAP een dunne, fel rode lijn. Dit betekent dat bij een hoog 'Na' gehalte we een grote positieve invloed op de uitkomst van het model kunnen verwachten, maar dat dit niet vaak voorkomt. Aan de andere kant van de 'Na' beeswarm, net iets onder een SHAP-waarde van 0, is de sina vrij dik en blauw. Dit laat zien dat een relatief lage 'Na' waarde een licht negatieve invloed heeft op de uitkomst, maar dat dit erg vaak voor komt.
Je kan de bar summary plot zien als een compacte versie van deze plot: als we de absolute van de SHAP waarden in deze plot nemen (dus de alles aan de linker kant van de 0 lijn als het ware spiegelen), en vervolgens het gemiddelde van de beeswarms nemen, dan krijg je de bovenstaande bar plot.
shap.summary_plot(shap_values, data.loc[:, ~data.columns.isin(["sad"])], cmap=plt.get_cmap("viridis"))
Partial dependence plots ¶
Om in te zoomen op de effectiviteit kunnen we kijken naar partial dependence plots. In de onderstaande plot zien we op de x as de waarde van de betreffende feature, en op de y as de verwachtingswaarde (gewogen gemiddelde) van de voorspelling bij die x-waarde. De horizontale stippellijn is de verwachtinswaarde van de voorspellingen op de volledige dataset, en de verticale die van de betreffende feature. de lichtgekleurde staven op de achtergrond zijn een histogram van de dataset op basis van de feature.
In het voorbeeld kijken we naar de partial dependence van 'gcs'. We zien dat de verwachtingswaarde het hoogst is bij een 'gcs' tussen de 10 en 12, dat wil zeggen dat iemand het meeste kans op SAD heeft wanneer 'gcs' tussen de 10 en 12 zit.
Deze plot is verder niet gebaseerd op SHAP, maar als we de relatie tussen 'gcs' en de bijbehorende SHAP-waarden plotten, zien we hoe goed die overeenkomen.
shap.partial_dependence_plot(
"gcs", # 'gcs' kan worden vervangen voor een andere feature om die te bekijken
lambda X: model.predict(xgb.DMatrix(X.drop(['sad'], axis=1))),
data,
ice=False,
model_expected_value=True,
feature_expected_value=True,
)
Zoals al gezegd is deze plot qua functie vergelijkbaar met de vorige.
Deze plot zegt echter ook iets over de relatie van de betreffende feature met andere features. In dit voorbeeld zien we dat wanneer de SHAP waarde voor 'gcs' relatief laag is, de patient vaak aan de beademing zit, echter bij de allerhoogste waarden die 'gcs' aanneemt dit effect omgekeerd is.
shap.dependence_plot(
'gcs', # 'gcs' kan worden vervangen voor een andere feature om die te bekijken
shap_values,
data.loc[:, ~data.columns.isin(["sad"])],
cmap=plt.get_cmap("viridis")
# interaction_index=None
) # `interaction_index=None` kan in-gecomment worden om de kleur weg te halen.
# Ook kan `None` vervangen worden door een feature naam (of index) om handmatig de feature waarmee we vergelijken te kiezen.
# het zou wellicht leuk zijn als dit interactief kan op het dashboard
Nog een voorbeeld van dezelfde plot op een andere feature: we zien hier dat 'aniongap' negatief correleert met 'bicarbonate'. Dit kan een indicatie zijn dat de een een proxy is voor de ander.
shap.dependence_plot(
'aniongap', # 'gcs' kan worden vervangen voor een andere feature om die te bekijken
shap_values,
data.loc[:, ~data.columns.isin(["sad"])],
interaction_index='bicarbonate',
cmap=plt.get_cmap("viridis")
)
Data Drift ¶
Maakt duidelijk in hoeverre het model afwijkt van de huidige werkelijkheid
data_drift_dataset_report = Report(metrics=[
DatasetDriftMetric(),
DataDriftTable(),
])
# voor het voorbeeld heb ik gwn de data in tweeën gesplitst, in de code moet 'reference_data' de data zijn waar het model op is getraind en 'current_data' de set die op dat moment aan de applicatie hangt
data_drift_dataset_report.run(reference_data=data[:int(data.shape[0]/2)], current_data=data[int(data.shape[0]/2):])
data_drift_dataset_report
LOCAL EXPLANATIONS ¶
Certainty score ¶
# vervang '1' voor de index van de respectievelijke patient
score = model.predict(xgb.DMatrix(data.loc[[1], ~data.columns.isin(["sad"])], label=data["sad"]))[0]
if score < 0.5: score = 1-score
print("Certainty score: " + str(round(score*100, 2)) + "%")
Certainty score: 63.03%
Outlier detection¶
data_cf.shape
(11196, 42)
record = data_cf.iloc[1]
is_categorical = ['race', 'first_careunit', 'vent', 'ckd', 'crrt', 'copd', 'gender', 'vaso', 'hyperte', 'seda', 'dm', 'aki', 'ami', 'stroke']
fig, axs = plt.subplots(6, 7, figsize=(12, 6))
i = 0
j = 0
score = 0
for c in data_cf.columns:
axs[i,j].hist(data_cf[c], bins=20)
title = c
if c not in is_categorical:
title = title + "\n σ = " + str(round(data_cf[c].std(), 2)) + "\n |x-μ| = " + str(round(abs(record[c] - data_cf[c].mean()), 2))
score += (abs(record[c] - data_cf[c].mean()) / data_cf[c].std())
axs[i,j].set_title(title, fontsize=8)
axs[i,j].get_xaxis().set_visible(False)
axs[i,j].get_yaxis().set_visible(False)
axs[i,j].axvline(record[c], color='r', linestyle='dashed', linewidth=1)
i+=1
if (i%6)==0:
i = 0
j += 1
score /= data_cf.shape[1]
fig.suptitle("Outlier score: " + str(score))
fig.tight_layout()
plt.show()
Counterfactuals ¶
Het globale idee van een counterfactual is om te redeneren over wat er zou zijn gebeurd als bepaalde omstandigheden anders waren geweest. Er wordt een hypothetisch scenario gecreëert waarin één of meerdere inputs worden veranderd waarvoor de voorspelling anders zou zijn.
Feature select ¶
Counterfactuals worden gegenereerd waarbij in één of meerdere vooraf bepaalde features gevariëerd wordt.
In het voorbeeld wordt er gevarieerd in de features age, weight, temperature en gcs. In de eerste cell zien we de oorspronkelijke record, daarna volgen 5 counterfactuals die hierop zijn gebaseerd.
Aan de tabel met counterfactuals zijn een aantal kolommen toegevoegd:
- 'reg': uitkomst van het model, uitgaande van deze counterfactual record
- 'pred': True/False uitkomst gebaseerd op 'reg'
- 'fitness': fitness score aan de hand waarvan het algoritme bepaalt hoe dicht deze counterfactual bij het origineel zit. In de huidige versie is het de som van de euclidische afstanden tussen de features.
data.drop(['sad'], axis=1).iloc[[1]]
| age | weight | gender | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | aki | stroke | AISAN | BLACK | HISPANIC | OTHER | WHITE | unknown | CCU | CVICU | MICU | MICU/SICU | NICU | SICU | TSICU | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 56.0 | 119.300003 | 0 | 36.720001 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.4 | 38.400002 | 15.0 | 14.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
dummy_groupings = {'race':['AISAN', 'BLACK', 'HISPANIC', 'OTHER', 'WHITE', 'unknown'], 'first_careunit': ['CCU', 'CVICU', 'MICU', 'MICU/SICU', 'NICU', 'SICU', 'TSICU']}
use_feats = ['age', 'weight', 'temperature', 'gcs']
# `limit` bepaalt hoe vaak het genetisch algoritme doorlopen wordt. een hoge limit geeft counterfactuals die veel variëren maar minder op de oorspronkelijke patiënt lijken, een lage limit geeft cfs die minder variëren maar meer op de patiënt lijken.
# gezien het algoritme eerder convergeert wanneer er in minder features gevariëerd wordt, zal `limit` lager gezet moeten worden bij een kortere `use_feats`
# wellicht is het een idee om `limit` interactief te maken op het dashboard, zodat de gebruiker die zelf in kan stellen voor een gewenst resultaat.
cf_g = GeneticCounterfactual(data.drop(['sad'], axis=1), model, dummy_groupings, use_feats=use_feats, limit=1, population_size=data.drop(['sad'], axis=1).shape[0])
cf_g.generate(data.drop(['sad'], axis=1).iloc[[1]], 5)
| age | weight | gender | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | aki | stroke | AISAN | BLACK | HISPANIC | OTHER | WHITE | unknown | CCU | CVICU | MICU | MICU/SICU | NICU | SICU | TSICU | reg | pred | fitness | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 56.0 | 119.449997 | 0.0 | 37.110001 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.4 | 38.400002 | 15.0 | 14.0 | 14.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1 | False | False | False | False | False | 1 | False | False | False | False | False | False | 0.504692 | True | 1.083789 |
| 1 | 56.0 | 119.449997 | 0.0 | 37.389999 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.4 | 38.400002 | 15.0 | 14.0 | 14.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1 | False | False | False | False | False | 1 | False | False | False | False | False | False | 0.564591 | True | 1.213011 |
| 2 | 55.0 | 119.449997 | 0.0 | 37.389999 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.4 | 38.400002 | 15.0 | 14.0 | 14.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1 | False | False | False | False | False | 1 | False | False | False | False | False | False | 0.557112 | True | 1.572067 |
| 3 | 55.0 | 119.449997 | 0.0 | 38.439999 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.4 | 38.400002 | 15.0 | 14.0 | 14.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1 | False | False | False | False | False | 1 | False | False | False | False | False | False | 0.557112 | True | 2.231791 |
| 4 | 60.0 | 121.500000 | 0.0 | 37.169998 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.4 | 38.400002 | 15.0 | 14.0 | 14.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1 | False | False | False | False | False | 1 | False | False | False | False | False | False | 0.509901 | True | 4.694942 |
Genetic ¶
Deze methode genereert op basis van een genetisch algoritme een volledig nieuwe patiënt waarvoor het model een andere voorspelling zou doen. De fictieve patiënt wordt zo gegenereerd dat die zo min mogelijk van de daadwerkelijke patiënt afwijkt.
data.drop(['sad'], axis=1).iloc[[1]]
| age | weight | gender | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | aki | stroke | AISAN | BLACK | HISPANIC | OTHER | WHITE | unknown | CCU | CVICU | MICU | MICU/SICU | NICU | SICU | TSICU | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 56.0 | 119.300003 | 0 | 36.720001 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.4 | 38.400002 | 15.0 | 14.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
cf_genetic = GeneticCounterfactual(data.drop(['sad'], axis=1), model, dummy_groupings, limit=10, population_size=data.shape[0])
cf_genetic.generate(data.drop(['sad'], axis=1).iloc[[1]], 5)
| age | weight | gender | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | aki | stroke | AISAN | BLACK | HISPANIC | OTHER | WHITE | unknown | CCU | CVICU | MICU | MICU/SICU | NICU | SICU | TSICU | reg | pred | fitness | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 72.0 | 111.500000 | 0.0 | 36.560001 | 81.0 | 21.0 | 94.0 | 72.0 | 55.0 | 67.0 | 19.6 | 10.3 | 40.0 | 72.0 | 1.1 | 81.0 | 129.0 | 106.0 | 3.9 | 1.8 | 7.9 | 4.0 | 1.4 | 20.299999 | 37.200001 | 15.0 | 15.0 | 14.0 | 1.0 | 0.0 | 0.0 | 1.0 | 7.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | False | False | False | False | 0 | False | False | False | 1 | False | False | False | False | 0.685672 | True | 22.003765 |
| 1 | 50.0 | 126.699997 | 1.0 | 37.000000 | 84.0 | 16.0 | 97.0 | 85.0 | 55.0 | 61.0 | 12.1 | 7.6 | 36.0 | 73.0 | 1.3 | 84.0 | 130.0 | 102.0 | 3.3 | 1.4 | 9.7 | 3.2 | 1.1 | 13.600000 | 32.400002 | 20.0 | 17.0 | 15.0 | 1.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | False | False | False | False | 1 | False | False | False | 0 | False | False | False | False | 0.703455 | True | 22.097925 |
| 2 | 50.0 | 126.699997 | 1.0 | 37.000000 | 83.0 | 16.0 | 97.0 | 85.0 | 55.0 | 61.0 | 12.1 | 7.6 | 36.0 | 74.0 | 1.3 | 84.0 | 130.0 | 94.0 | 3.3 | 2.2 | 8.4 | 3.2 | 1.1 | 17.900000 | 32.400002 | 20.0 | 17.0 | 15.0 | 1.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 1.0 | False | False | False | False | 1 | False | False | False | 0 | False | False | False | False | 0.709303 | True | 22.542367 |
| 3 | 72.0 | 124.949997 | 0.0 | 36.560001 | 90.0 | 22.0 | 97.0 | 72.0 | 55.0 | 61.0 | 12.1 | 9.6 | 31.0 | 72.0 | 2.3 | 81.0 | 123.0 | 101.0 | 3.4 | 1.8 | 8.4 | 0.9 | 1.8 | 20.299999 | 37.200001 | 20.0 | 14.0 | 15.0 | 1.0 | 0.0 | 1.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | False | False | False | False | 0 | False | False | False | 0 | False | False | False | False | 0.698935 | True | 23.325052 |
| 4 | 48.0 | 110.599998 | 1.0 | 37.000000 | 79.0 | 16.0 | 97.0 | 84.0 | 62.0 | 61.0 | 12.1 | 7.6 | 36.0 | 73.0 | 1.3 | 84.0 | 130.0 | 102.0 | 3.3 | 1.4 | 8.4 | 3.2 | 1.3 | 13.600000 | 32.400002 | 20.0 | 16.0 | 15.0 | 1.0 | 0.0 | 1.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | False | False | False | False | 1 | False | False | False | 0 | False | False | False | False | 0.721635 | True | 23.378590 |
KDTree ¶
Deze methode zoekt in de dataset, aan de hand van een KDTree algoritme, een bestaande patiënt waarvoor het model een andere voorspelling zou doen. De nieuwe patiënt wordt zo gekozen dat die zo min mogelijk van de huidige patiënt afwijkt.
de nieuwe kolom 'dst' geeft de euclidische afstand tussen de counterfactual en het origineel aan, het is dus vergelijkbaar met 'fitness' in de vorige.
data.drop(['sad'], axis=1).iloc[[1]]
| age | weight | gender | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | aki | stroke | AISAN | BLACK | HISPANIC | OTHER | WHITE | unknown | CCU | CVICU | MICU | MICU/SICU | NICU | SICU | TSICU | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 56.0 | 119.300003 | 0 | 36.720001 | 82.0 | 22.0 | 90.0 | 75.0 | 56.0 | 61.0 | 13.0 | 7.2 | 36.0 | 70.0 | 2.7 | 83.0 | 128.0 | 103.0 | 4.0 | 2.2 | 7.0 | 4.5 | 2.1 | 22.4 | 38.400002 | 15.0 | 14.0 | 15.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 |
cf_kdtree = KDTreeCounterFactual(data.drop(['sad'], axis=1), model)
cf_kdtree.generate(data.drop(['sad'], axis=1).iloc[[1]], 5)
| age | weight | gender | temperature | heart_rate | resp_rate | spo2 | sbp | dbp | mbp | wbc | hemoglobin | platelet | bun | cr | glu | Na | Cl | K | Mg | Ca | P | inr | pt | ptt | bicarbonate | aniongap | gcs | vent | crrt | vaso | seda | sofa_score | ami | ckd | copd | hyperte | dm | aki | stroke | AISAN | BLACK | HISPANIC | OTHER | WHITE | unknown | CCU | CVICU | MICU | MICU/SICU | NICU | SICU | TSICU | reg | pred | dst | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2912 | 62.0 | 103.000000 | 0 | 33.200001 | 78.0 | 16.0 | 97.0 | 89.0 | 51.0 | 58.0 | 10.9 | 8.1 | 55.0 | 101.0 | 8.6 | 70.0 | 138.0 | 105.0 | 4.9 | 2.9 | 7.9 | 6.6 | 2.5 | 27.200001 | 45.900002 | 17.0 | 21.0 | 4.0 | 0.0 | 0.0 | 1.0 | 0.0 | 8 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0.870129 | True | 50.399212 |
| 6348 | 69.0 | 104.000000 | 1 | 36.500000 | 93.0 | 31.0 | 93.0 | 71.0 | 43.0 | 50.0 | 0.9 | 11.5 | 37.0 | 38.0 | 3.0 | 101.0 | 136.0 | 103.0 | 6.1 | 2.3 | 8.9 | 9.7 | 1.6 | 18.100000 | 37.400002 | 15.0 | 28.0 | 12.0 | 1.0 | 1.0 | 1.0 | 1.0 | 3 | 1.0 | 1.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0.841014 | True | 52.857719 |
| 12476 | 54.0 | 143.500000 | 0 | 36.610001 | 82.0 | 8.0 | 98.0 | 83.0 | 42.0 | 49.0 | 7.1 | 9.5 | 36.0 | 68.0 | 2.3 | 118.0 | 135.0 | 101.0 | 5.2 | 2.3 | 8.6 | 6.0 | 3.1 | 34.799999 | 47.200001 | 26.0 | 13.0 | 12.0 | 0.0 | 0.0 | 1.0 | 0.0 | 4 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0.681989 | True | 54.418489 |
| 6652 | 62.0 | 94.500000 | 1 | 36.560001 | 92.0 | 26.0 | 99.0 | 89.0 | 43.0 | 55.0 | 6.7 | 7.2 | 67.0 | 43.0 | 2.7 | 68.0 | 133.0 | 107.0 | 4.6 | 1.9 | 7.1 | 4.2 | 2.2 | 22.600000 | 38.599998 | 17.0 | 14.0 | 15.0 | 1.0 | 0.0 | 1.0 | 1.0 | 6 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0.650931 | True | 57.073599 |
| 792 | 75.0 | 105.900002 | 0 | 36.560001 | 83.0 | 20.0 | 100.0 | 103.0 | 66.0 | 77.0 | 1.1 | 9.3 | 61.0 | 76.0 | 2.4 | 108.0 | 139.0 | 101.0 | 4.4 | 2.6 | 7.7 | 4.5 | 1.4 | 15.500000 | 21.400000 | 23.0 | 15.0 | 15.0 | 1.0 | 0.0 | 0.0 | 0.0 | 5 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0.645523 | True | 61.315624 |
SHAP ¶
Force plot ¶
Deze plot laat zien in hoe verre bepaalde features hebben bijgedragen aan de voorspelling. Een blauwe feature draagt negatief bij aan de voorspelling en een rode positief. Een langer balkje indiceert dat een feature meer bijgedragen heeft. 'f(x)' is de gemiddelde SHAP-waarde van deze voorspelling, en 'base value' is de verwachtingswaarde van alle SHAP-waarden in de set.
row = 0 # `row` is de index van de huidige patient
shap.force_plot(explainer.expected_value, shap_values[row, :], data.loc[:, ~data.columns.isin(["sad"])].iloc[row, :], plot_cmap=["#FDE725", "#440154"])
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
Waterfall ¶
De waterfall-plot is eigenlijk hetzelfde als de vorige maar dan staat elke feature los, en op volgorde van importance. Dit kan wat overzichtelijker zijn.
# De SHAP-waarden in deze wijken iets af van de andere visualisaties, dit is omdat de rest is gebaseerd op `TreeExplainer` en deze `Explainer` van de SHAP library.
# Deze plot kan het best achterwege gelaten worden omdat:
# - Kleuren kunnen niet veranderd worden
# - De plot kan niet op een TreeExplainer gemaakt worden
# - De plot komt toch al erg overeen met de decision plots
# de eerste twee punten kunnen overkomen worden door buiten de library te werken, maar dit lijkt de moeite niet waard
shap.plots.waterfall(shap_waterfall[0], max_display=14)
Decision plots ¶
Deze is vergelijkbaar met de waterfall-plot. Als we van beneden naar boven de lijn volgen zien we hoe iedere feature de uitkomst beinvloedt.
De decision plot wijkt wel af van de waterfall-plot in dat we op de x-as niet de cumulatieve SHAP waarde hebben, maar de uitkomst van het model (boven 0.5 is de voorspelling SAD, en daaronder NON-SAD). De grijze lijn is het gewogen gemiddelde van alle voorspellingen op de set.
row = 0
shap.decision_plot(
explainer.expected_value,
shap_values[row, :],
data.loc[:, ~data.columns.isin(["sad"])].iloc[0, :],
link="logit",
highlight=0,
plot_color=plt.get_cmap("viridis")
)
Hier is de decision plot uitgebreid om de patiënt te vergelijken met andere patiënten naar keuze, in dit geval de eerste 5 patiënten uit de dataset
row_current = 0
rows = [0,1,2,3,4] # indices van de patienten waarmee we de huidige patient willen vergelijken
shap.decision_plot(
explainer.expected_value,
shap_values[rows, :],
data.loc[:, ~data.columns.isin(["sad"])].iloc[0, :],
link="logit",
highlight=row_current,
plot_color=plt.get_cmap("viridis")
)
We kunnen de interactieve plot ook gebruiken voor locale uitleg; hieronder de interactieve plot voor de zelfde 5 patiënten als de vorige plot.
shap.force_plot(
explainer.expected_value, shap_values[[0,1,2,3,4], :], plot_cmap=["#FDE725", "#440154"]
)
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.